server: implement GLM-style MTP #15225

Draft: wants to merge 5 commits into master
Conversation

@F1LM1 commented Aug 11, 2025

This is very much a draft/proof of concept I'm playing with, just one idea for an MTP implementation. I'm planning to test on GLM-4.5 because it's the only model out there for which we've preserved the NextN tensors.

From what I can tell:

  • the three models with MTP implemented in vLLM right now are all "DeepseekV3-style,"
  • they only have one MTP head, which predicts the token at position n+2,
  • the MTP layers take as input the output embedding from the last conventional layer and their own input embedding.

So implementation-wise, it seems like:

  • we should try to reuse the existing speculative decode functionality (including nice stuff like main model KV cache management, various samplers, etc.),
  • but a lot of the full draft-model functionality is redundant or even harmful here, such as context/cache management for the draft model and vocab matching,
  • it probably makes sense to write a new, vastly simplified function like mtp_speculative_gen_draft in speculative.cpp and branch into it in server.cpp when a slot has MTP (versus common_speculative_gen_draft); see the sketch after this list.
  • AFAICT the server.cpp loop currently alternates between a conventional forward pass and a draft pass, which in the MTP case will probably sabotage the performance gains (our max throughput would be only 1.5 tok/pass assuming zero rejections, instead of 2 tok/pass). Let me know if this isn't the case! But if it is, we should probably avoid doing non-speculative decodes after the first response token.
  • It also doesn't make sense to have to manage a distinct ctx_dft in this case. It's a bit hacky, but I was thinking we could just set ctx_dft = ctx and have both the normal and MTP passes write over the shared ctx logits; I think this minimizes the required code changes elsewhere.
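As a rough illustration of how simple the drafting function could become, here is a minimal sketch (not the PR's actual code). It assumes the shared-context hack above, i.e. the MTP head has already written its logits into the regular output slot idx, and that GLM-4.5's single NextN head means a draft of at most one token. mtp_speculative_gen_draft is a placeholder name; the llama_* calls are existing llama.cpp API:

```cpp
#include <vector>
#include "llama.h"

// Hypothetical, vastly simplified draft generator: no draft-model KV cache,
// no vocab mapping, no extra decode call.
static std::vector<llama_token> mtp_speculative_gen_draft(llama_context * ctx, int32_t idx) {
    std::vector<llama_token> draft;

    // Logits for output position idx, written by the MTP pass (per the shared-ctx idea).
    const float * logits = llama_get_logits_ith(ctx, idx);
    if (logits == nullptr) {
        return draft; // nothing to draft from
    }

    const llama_vocab * vocab   = llama_model_get_vocab(llama_get_model(ctx));
    const int32_t       n_vocab = llama_vocab_n_tokens(vocab);

    // Greedy pick is enough for a draft token; the main model verifies it anyway.
    llama_token best = 0;
    for (int32_t t = 1; t < n_vocab; ++t) {
        if (logits[t] > logits[best]) {
            best = t;
        }
    }

    // GLM-4.5 ships a single MTP head, so the draft is at most one token.
    draft.push_back(best);
    return draft;
}
```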

This is my first time (1) working with ML stuff outside of Python and (2) attempting to contribute, so patience is appreciated :)

@ggerganov added the hot label on Aug 12, 2025
@ggerganov (Member)

AFAICT the server.cpp loop currently alternates between a conventional forward pass and a draft pass, which in the MTP case will probably sabotage the performance gains (our max throughput would be only 1.5 tok/pass assuming zero rejections, instead of 2 tok/pass). Let me know if this isn't the case! But if it is, we should probably avoid doing non-speculative decodes after the first response token.

This is correct - we always alternate between conventional and speculative passes. It's definitely not optimal, but it improves flexibility for regular sampling: it allows changing the speculative parameters and even disabling speculation per request, while keeping the logic quite simple.

It should be possible to improve this by keeping track of which slots are speculating on each iteration and skipping adding tokens to the conventional batch for them. It might be a good idea to implement this separately, to avoid huge changes to the logic in a single PR.
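For what it's worth, a self-contained sketch of that bookkeeping (hypothetical struct and field names, not actual server.cpp code) could look something like this:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical, trimmed-down view of a server slot.
struct server_slot {
    int32_t id             = -1;
    int32_t sampled        = -1;    // last sampled token id
    int32_t n_past         = 0;
    bool    is_speculating = false; // hypothetical flag: slot is served by the speculative path
};

// Decide which slots get a token added to the conventional batch this iteration.
static std::vector<int32_t> slots_for_conventional_batch(const std::vector<server_slot> & slots) {
    std::vector<int32_t> ids;
    for (const auto & slot : slots) {
        if (slot.is_speculating) {
            continue; // this slot's next tokens come from the speculative pass
        }
        ids.push_back(slot.id);
    }
    return ids;
}
```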

@ggerganov (Member)

Generally we should try to minimize the changes to llama.h, since changing/extending the public API requires a lot of effort.

On first look, I think the path that involves minimal changes is:

  • Add an int n_mtp field to llama_context_params (default = 1: MTP is disabled; 2: predict logits for one additional token; 3: predict logits for two additional tokens; etc.)
  • Use this flag during graph build to determine if the MTP heads should be appended to the graph
  • Keep the conventional logits in the t_logits tensor in llm_graph_result
  • Add a new tensor t_logits_mtp (or whatever name is more appropriate) in llm_graph_result and use it to store the MTP results (a rough sketch of these additions follows this list)
  • In llama_decode() extract the t_logits_mtp data when available, following the same logic as for t_logits
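A hypothetical sketch of the shape of these additions (placeholder struct names, not actual llama.cpp code; the field and tensor names follow the list above):

```cpp
#include <cstdint>

struct ggml_tensor; // from ggml.h

// llama.h side: one new context parameter
struct llama_context_params_sketch {
    // ... existing llama_context_params fields ...
    int32_t n_mtp; // 1 = MTP disabled, 2 = logits for 1 extra token, 3 = logits for 2 extra tokens, ...
};

// llama-graph.h side: one new output tensor next to t_logits
struct llm_graph_result_sketch {
    ggml_tensor * t_logits;     // existing: conventional logits
    ggml_tensor * t_logits_mtp; // new: MTP logits, extracted in llama_decode() when present
};
```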

Extracting the MTP logits during llama_decode() can be done in 2 ways:

  • Create a separate buffer in the llama_context to store them and add a new llama_get_logits_mtp_ith() API that works with the new buffer in a similar way to the existing llama_get_logits_ith()
  • Reuse the existing logits buffer by expanding it from [n_outputs][n_vocab] to [n_outputs][n_mtp*n_vocab]. This would avoid the need to add llama_get_logits_mtp_ith(), and we can generalize the existing llama_get_logits_ith() by taking the value of n_mtp into account (see the indexing sketch below).
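To make the second option concrete, here is a sketch of the indexing math only (hypothetical layout, assuming the conventional logits come first in each output row):

```cpp
#include <cstdint>

// Logits stored as [n_outputs][n_mtp * n_vocab]: conventional logits first,
// then the MTP head(s), for each output row.
// Start of the logits for output row i and head k
// (k = 0 is the conventional head, k = 1 .. n_mtp-1 are the MTP heads).
static int64_t logits_offset(int64_t i, int64_t k, int64_t n_mtp, int64_t n_vocab) {
    return (i * n_mtp + k) * n_vocab;
}
// A generalized llama_get_logits_ith(ctx, i) would keep returning the k = 0 slice,
// so existing callers stay valid, while MTP-aware callers read the k >= 1 slices.
```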

Currently, I am not sure which way is better. The first requires a new API call, while the second might break some existing assumptions (not sure if that's the case yet).

In any case, you can avoid this until you get the implementation working with a reasonable speedup. After that, we can discuss further how to best refactor the implementation.

@slaren (Member) commented Aug 13, 2025

Currently, I am not sure which way is better. The first requires a new API call, while the second might break some existing assumptions (not sure if that's the case yet).

I don't see an issue with adding a new API for this, and it would be easier to use.

@Juahyori

Out of curiosity, is the API for this expected to be flexible enough that we could build on it to add things like Medusa/EAGLE-style (or IBM Accelerator) self-speculative decoding heads?

I'm pretty sure they work fairly similarly (depending on the final output embeddings of the current token).

Another note:

After some consideration, I think the expected speedup of the MTP module will depend a lot on the hardware the model is running on, particularly because it's an MoE model. While the next-token prediction depends only on the current state, self-speculative decoding adds extra forward passes, and those passes aren't guaranteed to hit the same expert usage patterns. So the speedup should be some function of the number of tokens predicted and the expert re-use coefficient across the tokens being verified.

So, just noting that if it's implemented and there isn't a 2x or 3x increase in T/s, it may not be a skill issue on the contributor's part, but a consequence of the mathematical nature of the calculation.
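As a back-of-envelope illustration of that dependence (my own rough framing with assumed notation, not an established formula): with a single MTP head, let p be the probability that the drafted token is accepted and c >= 1 the cost of a verification pass relative to a plain forward pass, which grows when the extra token drags in experts that would otherwise stay cold. Then roughly:

```latex
% Back-of-envelope sketch (assumed notation):
%   p = acceptance probability of the single drafted token
%   c = relative cost of a verification pass vs. a plain forward pass
%       (c > 1 when the extra token activates additional experts)
% Each verification pass yields 1 + p expected tokens at relative cost c, so
\[
    \text{speedup} \approx \frac{1 + p}{c}
\]
% i.e. the ideal 2x only when p = 1 and c = 1; lower acceptance or poor expert
% re-use pulls the real-world gain below that.
```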

For people running franken setups with attention/KV cache on the GPU and MoE FFNs on the CPU, it's possible that pulling in previously unused experts during the verification sweep results in a weird situation where the parallel verification pass is actually memory-bandwidth bound.

This isn't meant to discourage the implementation; I just wanted to give a heads-up so nobody is dejected if the theoretical speedups can't be hit. There should still be at least some speedup, though.
